import pandas as pd
import numpy as np
import start


# %%

# Load the gold-standard character classification spreadsheet from the
# project's cleaned-data directory (MAIN_DIR comes from the local `start` module).
gold_standard_path = start.MAIN_DIR + "data/clean/character_classifications_gold.xlsx"
gold_standard = pd.read_excel(gold_standard_path)

# %%

# Tabulate how many gold-standard rows of each character category fall into
# each data split, add per-category and grand totals, print the table and
# write it to Excel.
sets = ["train", "dev", "test"]
categories = ["hero", "villain", "victim", "other"]

# counts[set_name][i] is the number of rows in split `set_name` whose
# `character_gold` label equals categories[i]. value_counts() runs once per
# split (not once per category); reindex() fixes the category order and fills
# labels absent from that split with 0. Labels not listed in `categories`,
# and rows whose `set` is not in `sets`, are not counted.
counts = {
    set_name: (
        gold_standard.loc[gold_standard["set"] == set_name, "character_gold"]
        .value_counts()
        .reindex(categories, fill_value=0)
        .tolist()
    )
    for set_name in sets
}

# Per-category total across all splits, in the same order as `categories`.
total_counts = [sum(per_category) for per_category in zip(*counts.values())]

summary_table = pd.DataFrame(counts, index=categories)
summary_table["Total"] = total_counts
# Append a grand-total row that sums every column, including "Total".
summary_table.loc["Total"] = summary_table.sum()

print(summary_table)
summary_table.to_excel(start.MAIN_DIR + "results/gold_standard_character_counts.xlsx")
# %%
